import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import math

# ============================================================
# CNVS EXTREME MONTE CARLO TEST 7
# K_adv + Shannon residual entropy + fragment inference
# ============================================================

# Single seeded generator shared by every random draw in the simulation.
rng = np.random.default_rng(20260525)

# Size of the protected information, in bits.
I_G_BITS = 1000

# Scenario table: one row per scenario, columns named by _SCENARIO_KEYS.
_SCENARIO_KEYS = ("label", "I0", "lambda_factor", "rho_c", "cluster")
_SCENARIO_ROWS = (
    (r"Coarse $I_0$=100 $\rho_c$=0.010", 100, 2, 0.010, 1.0),
    (r"Medium $I_0$=50 $\rho_c$=0.006", 50, 2, 0.006, 1.0),
    (r"High $I_0$=20 $\rho_c$=0.003", 20, 2, 0.003, 1.0),
    (r"Extreme CNVS $I_0$=5 $\rho_c$=0.001", 5, 2, 0.001, 1.0),
    (r"Stress inference $I_0$=5 $\rho_c$=0.003 cluster=2", 5, 2, 0.003, 2.0),
)
SCENARIOS = [dict(zip(_SCENARIO_KEYS, row)) for row in _SCENARIO_ROWS]

# Compromised fractions swept for each scenario, and trials per sweep point.
Q_VALUES = np.linspace(0.01, 0.99, num=75)
TRIALS = 1200


def make_pool(M_fragments, lambda_factor):
    """Return the node pool as an array of fragment ids.

    Each id in ``0 .. M_fragments - 1`` appears ``lambda_factor`` times in
    consecutive positions, so the result has
    ``M_fragments * lambda_factor`` entries.
    """
    fragment_ids = np.arange(M_fragments)
    return fragment_ids.repeat(lambda_factor)


def validate_parameters(I_G, I0, lambda_factor, q, rho_c, cluster):
    """Check the simulation parameters and raise on the first invalid one.

    Args:
        I_G: total information size; must be > 0.
        I0: fragment size; must be > 0.
        lambda_factor: replication count; a Python or NumPy integer >= 1.
        q: captured fraction of the pool; must lie in [0, 1].
        rho_c: inference coefficient; must lie in [0, 1].
        cluster: multiplier applied to ``rho_c``; must be >= 0.

    Returns:
        None when every parameter is acceptable.

    Raises:
        ValueError: if any parameter is out of range or NaN.
    """
    # Every range check is written in its positive form and negated, because
    # all comparisons with NaN are False: ``nan <= 0`` would let a NaN slip
    # through, whereas ``not (nan > 0)`` rejects it.
    if not I_G > 0:
        raise ValueError("I_G must be positive.")
    if not I0 > 0:
        raise ValueError("I0 must be positive.")
    # NumPy integer scalars (e.g. values read back from an array or a
    # DataFrame) are not instances of ``int`` but are valid repeat counts.
    if not isinstance(lambda_factor, (int, np.integer)) or lambda_factor < 1:
        raise ValueError("lambda_factor must be an integer >= 1.")
    if not (0 <= q <= 1):
        raise ValueError("q must be in [0, 1].")
    if not (0 <= rho_c <= 1):
        raise ValueError("rho_c must be in [0, 1].")
    if not cluster >= 0:
        raise ValueError("cluster must be >= 0.")


def one_trial(I_G, I0, lambda_factor, q, rho_c, cluster):
    """Run one Monte Carlo trial and return its metrics as a dict.

    Steps, in order:
      1. Validate the parameters (raises ``ValueError`` on bad input).
      2. Build a pool of ``M = ceil(I_G / I0)`` fragment ids, each repeated
         ``lambda_factor`` times, and sample ``round(q * len(pool))`` entries
         without replacement.
      3. Count the distinct fragments captured, then draw the number of
         missing fragments that are inferred from a binomial with success
         probability ``1 - (1 - rho_eff) ** captured_distinct``, where
         ``rho_eff = min(1, rho_c * cluster)``.
      4. Residual entropy is ``missing_after * I0``; the guess probability
         is ``2 ** -H_res``, treated as 0 once ``H_res`` exceeds 1024.
      5. The attacker wins if nothing is missing or a uniform draw falls
         below the guess probability.

    Returns:
        dict with keys ``K_direct``, ``K_inferential``, ``K_adv``, ``H_res``,
        ``P_infer``, ``P_guess``, ``attacker_wins``, ``missing_after``,
        ``M_fragments`` and ``N_total``.
    """
    validate_parameters(I_G, I0, lambda_factor, q, rho_c, cluster)

    M = math.ceil(I_G / I0)
    pool = make_pool(M, lambda_factor)
    N_total = len(pool)

    # Direct capture: sample pool entries, then count distinct fragment ids.
    n_captured = int(round(q * N_total))
    captured = rng.choice(pool, size=n_captured, replace=False)
    n_direct = len(np.unique(captured))
    n_missing = M - n_direct

    # Inference: each missing fragment is recovered with probability P_infer.
    rho_eff = min(1.0, rho_c * cluster)
    P_infer = 1.0 - (1.0 - rho_eff) ** n_direct
    if n_missing > 0:
        n_inferred = rng.binomial(n_missing, P_infer)
    else:
        n_inferred = 0

    n_known = min(M, n_direct + n_inferred)
    missing_after = M - n_known

    # Residual entropy and the chance of guessing what is still missing.
    H_res = missing_after * I0
    if H_res > 1024:
        P_guess = 0.0
    else:
        P_guess = 2.0 ** (-H_res)

    # The uniform draw is taken on every trial, even when reconstruction is
    # already complete, so the random stream advances identically each time.
    fully_reconstructed = (n_known == M)
    guessed = (rng.random() < P_guess)

    return {
        "K_direct": n_direct / M,
        "K_inferential": n_inferred / M,
        "K_adv": n_known / M,
        "H_res": H_res,
        "P_infer": P_infer,
        "P_guess": P_guess,
        "attacker_wins": fully_reconstructed or guessed,
        "missing_after": missing_after,
        "M_fragments": M,
        "N_total": N_total,
    }


# Scenario entries forwarded to one_trial and copied into each summary row.
_PARAM_KEYS = ("I0", "lambda_factor", "rho_c", "cluster")

# Summary column -> per-trial field whose mean it holds.
_MEAN_COLUMNS = {
    "P_win": "attacker_wins",
    "mean_K_direct": "K_direct",
    "mean_K_inferential": "K_inferential",
    "mean_K_adv": "K_adv",
    "mean_H_res": "H_res",
    "mean_P_infer": "P_infer",
    "mean_P_guess": "P_guess",
    "mean_missing_after": "missing_after",
}

all_results = []

# Sweep every scenario over every q, running TRIALS trials per point and
# reducing them to one row of means.
for scenario in SCENARIOS:
    params = {key: scenario[key] for key in _PARAM_KEYS}

    for q in Q_VALUES:
        df = pd.DataFrame(
            [one_trial(I_G=I_G_BITS, q=q, **params) for _ in range(TRIALS)]
        )

        means = {
            column: df[field].mean()
            for column, field in _MEAN_COLUMNS.items()
        }
        all_results.append(
            {"label": scenario["label"], "q": q, **params, **means}
        )

results = pd.DataFrame(all_results)

print(results.head())


# ============================================================
# PLOT 1 — RECONSTRUCTION PROBABILITY
# ============================================================

# Attacker win probability versus compromised fraction, one curve per
# scenario label.
plt.figure(figsize=(12, 7))

# sort=False keeps the groups in the order the labels first appear in
# `results` instead of pandas' default alphabetical key sort, so the legend
# entries and line colours follow the order in which scenarios were run.
for label, g in results.groupby("label", sort=False):
    plt.plot(g["q"], g["P_win"], linewidth=2.4, label=label)

# Vertical marker at q = 1/3, labelled as the BFT reference line.
plt.axvline(x=1/3, color="black", linestyle="--", alpha=0.7, label="BFT reference line (1/3)")
plt.xlabel("Fraction of network physically compromised by attacker (q)")
plt.ylabel("Probability of complete unauthorized reconstruction")
plt.title(r"CNVS Extreme Monte Carlo Test 7 — $K_{adv}$ + Residual Shannon Entropy")
plt.grid(True, linestyle=":", alpha=0.7)
plt.legend(fontsize=8)
plt.tight_layout()
plt.show()


# ============================================================
# PLOT 2 — K_adv
# ============================================================

# Mean adversarial knowledge versus compromised fraction, one curve per
# scenario label.
plt.figure(figsize=(12, 7))

# sort=False keeps the groups in the order the labels first appear in
# `results` instead of pandas' default alphabetical key sort, so the legend
# entries and line colours follow the order in which scenarios were run.
for label, g in results.groupby("label", sort=False):
    plt.plot(g["q"], g["mean_K_adv"], linewidth=2.4, label=label)

# Vertical marker at q = 1/3, labelled as the BFT reference line.
plt.axvline(x=1/3, color="black", linestyle="--", alpha=0.7, label="BFT reference line (1/3)")
plt.xlabel("Fraction of network physically compromised by attacker (q)")
plt.ylabel(r"Mean adversarial knowledge $K_{adv}$")
plt.title(r"CNVS Extreme Monte Carlo Test 7 — Mean $K_{adv}$")
plt.grid(True, linestyle=":", alpha=0.7)
plt.legend(fontsize=8)
plt.tight_layout()
plt.show()


# ============================================================
# PLOT 3 — RESIDUAL ENTROPY
# ============================================================

# Mean residual entropy (bits) versus compromised fraction, one curve per
# scenario label.
plt.figure(figsize=(12, 7))

# sort=False keeps the groups in the order the labels first appear in
# `results` instead of pandas' default alphabetical key sort, so the legend
# entries and line colours follow the order in which scenarios were run.
for label, g in results.groupby("label", sort=False):
    plt.plot(g["q"], g["mean_H_res"], linewidth=2.4, label=label)

# Vertical marker at q = 1/3, labelled as the BFT reference line.
plt.axvline(x=1/3, color="black", linestyle="--", alpha=0.7, label="BFT reference line (1/3)")
plt.xlabel("Fraction of network physically compromised by attacker (q)")
plt.ylabel(r"Mean residual entropy $H_{res}$ (bits)")
plt.title("CNVS Extreme Monte Carlo Test 7 — Residual Entropy")
plt.grid(True, linestyle=":", alpha=0.7)
plt.legend(loc="upper left", fontsize=8)
plt.tight_layout()
plt.show()